{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# k近邻算法教程"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 0. 添加依赖"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import pandas as pd\n",
    "from sklearn.datasets import load_iris # 用来加载iris数据\n",
    "from sklearn.model_selection import train_test_split # 切分训练集和测试集\n",
    "from sklearn.metrics import accuracy_score # 计算分类数据的预测准确度"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 1. 数据加载和预处理"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Load the iris dataset; its .data, .target, .feature_names and\n",
    "# .target_names attributes are used in the cells below.\n",
    "iris = load_iris()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>sepal length (cm)</th>\n",
       "      <th>sepal width (cm)</th>\n",
       "      <th>petal length (cm)</th>\n",
       "      <th>petal width (cm)</th>\n",
       "      <th>class</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>5.1</td>\n",
       "      <td>3.5</td>\n",
       "      <td>1.4</td>\n",
       "      <td>0.2</td>\n",
       "      <td>setosa</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>4.9</td>\n",
       "      <td>3.0</td>\n",
       "      <td>1.4</td>\n",
       "      <td>0.2</td>\n",
       "      <td>setosa</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>4.7</td>\n",
       "      <td>3.2</td>\n",
       "      <td>1.3</td>\n",
       "      <td>0.2</td>\n",
       "      <td>setosa</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3</th>\n",
       "      <td>4.6</td>\n",
       "      <td>3.1</td>\n",
       "      <td>1.5</td>\n",
       "      <td>0.2</td>\n",
       "      <td>setosa</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4</th>\n",
       "      <td>5.0</td>\n",
       "      <td>3.6</td>\n",
       "      <td>1.4</td>\n",
       "      <td>0.2</td>\n",
       "      <td>setosa</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "   sepal length (cm)  sepal width (cm)  petal length (cm)  petal width (cm)  \\\n",
       "0                5.1               3.5                1.4               0.2   \n",
       "1                4.9               3.0                1.4               0.2   \n",
       "2                4.7               3.2                1.3               0.2   \n",
       "3                4.6               3.1                1.5               0.2   \n",
       "4                5.0               3.6                1.4               0.2   \n",
       "\n",
       "    class  \n",
       "0  setosa  \n",
       "1  setosa  \n",
       "2  setosa  \n",
       "3  setosa  \n",
       "4  setosa  "
      ]
     },
     "execution_count": 5,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# One row per flower, one column per feature name\n",
    "df = pd.DataFrame(data=iris.data, columns=iris.feature_names)\n",
    "\n",
    "# Add the class label as a readable name. Indexing target_names with the\n",
    "# integer targets works for any number of classes, unlike a hand-written\n",
    "# {0: ..., 1: ..., 2: ...} mapping, and avoids writing the column twice.\n",
    "df['class'] = np.asarray(iris.target_names)[iris.target]\n",
    "\n",
    "df.head()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>sepal length (cm)</th>\n",
       "      <th>sepal width (cm)</th>\n",
       "      <th>petal length (cm)</th>\n",
       "      <th>petal width (cm)</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>count</th>\n",
       "      <td>150.000000</td>\n",
       "      <td>150.000000</td>\n",
       "      <td>150.000000</td>\n",
       "      <td>150.000000</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>mean</th>\n",
       "      <td>5.843333</td>\n",
       "      <td>3.057333</td>\n",
       "      <td>3.758000</td>\n",
       "      <td>1.199333</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>std</th>\n",
       "      <td>0.828066</td>\n",
       "      <td>0.435866</td>\n",
       "      <td>1.765298</td>\n",
       "      <td>0.762238</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>min</th>\n",
       "      <td>4.300000</td>\n",
       "      <td>2.000000</td>\n",
       "      <td>1.000000</td>\n",
       "      <td>0.100000</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>25%</th>\n",
       "      <td>5.100000</td>\n",
       "      <td>2.800000</td>\n",
       "      <td>1.600000</td>\n",
       "      <td>0.300000</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>50%</th>\n",
       "      <td>5.800000</td>\n",
       "      <td>3.000000</td>\n",
       "      <td>4.350000</td>\n",
       "      <td>1.300000</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>75%</th>\n",
       "      <td>6.400000</td>\n",
       "      <td>3.300000</td>\n",
       "      <td>5.100000</td>\n",
       "      <td>1.800000</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>max</th>\n",
       "      <td>7.900000</td>\n",
       "      <td>4.400000</td>\n",
       "      <td>6.900000</td>\n",
       "      <td>2.500000</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "       sepal length (cm)  sepal width (cm)  petal length (cm)  \\\n",
       "count         150.000000        150.000000         150.000000   \n",
       "mean            5.843333          3.057333           3.758000   \n",
       "std             0.828066          0.435866           1.765298   \n",
       "min             4.300000          2.000000           1.000000   \n",
       "25%             5.100000          2.800000           1.600000   \n",
       "50%             5.800000          3.000000           4.350000   \n",
       "75%             6.400000          3.300000           5.100000   \n",
       "max             7.900000          4.400000           6.900000   \n",
       "\n",
       "       petal width (cm)  \n",
       "count        150.000000  \n",
       "mean           1.199333  \n",
       "std            0.762238  \n",
       "min            0.100000  \n",
       "25%            0.300000  \n",
       "50%            1.300000  \n",
       "75%            1.800000  \n",
       "max            2.500000  "
      ]
     },
     "execution_count": 6,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Summary statistics (count, mean, std, min, quartiles, max) of the numeric feature columns\n",
    "df.describe()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "(150, 4) (150, 1)\n"
     ]
    }
   ],
   "source": [
    "# Features are used as loaded; labels become an (n_samples, 1) column vector\n",
    "x, y = iris.data, iris.target[:, np.newaxis]\n",
    "print(x.shape, y.shape)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "(105, 4) (105, 1)\n",
      "(45, 4) (45, 1)\n"
     ]
    }
   ],
   "source": [
    "# Split into training and test sets: 30% held out for testing, stratified\n",
    "# on the labels, with a fixed seed so the split is reproducible\n",
    "x_train, x_test, y_train, y_test = train_test_split(\n",
    "    x,\n",
    "    y,\n",
    "    test_size=0.3,\n",
    "    random_state=42,\n",
    "    stratify=y,\n",
    ")\n",
    "\n",
    "print(*(part.shape for part in (x_train, y_train)))\n",
    "print(*(part.shape for part in (x_test, y_test)))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 2. 算法实现"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [],
   "source": [
    "# 距离函数\n",
    "def l1_dist(a, b):\n",
    "    \"\"\"Manhattan (L1) distance between ``a`` and ``b``.\n",
    "\n",
    "    The absolute differences are summed over the last axis, so ``a`` may be\n",
    "    a 2-D array of row vectors (one distance per row is returned) or a\n",
    "    single 1-D point (a scalar is returned). ``b`` is broadcast against\n",
    "    ``a``; plain Python sequences are accepted as well as arrays.\n",
    "    \"\"\"\n",
    "    # axis=-1 equals the former axis=1 for 2-D input and also covers 1-D points\n",
    "    return np.sum(np.abs(np.asarray(a) - np.asarray(b)), axis=-1)\n",
    "\n",
    "def l2_dist(a, b):\n",
    "    \"\"\"Euclidean (L2) distance between ``a`` and ``b``.\n",
    "\n",
    "    The squared differences are summed over the last axis before taking the\n",
    "    square root, so ``a`` may be a 2-D array of row vectors (one distance\n",
    "    per row is returned) or a single 1-D point (a scalar is returned).\n",
    "    ``b`` is broadcast against ``a``; plain Python sequences are accepted\n",
    "    as well as arrays.\n",
    "    \"\"\"\n",
    "    # axis=-1 equals the former axis=1 for 2-D input and also covers 1-D points\n",
    "    diff = np.asarray(a) - np.asarray(b)\n",
    "    return np.sqrt(np.sum(diff ** 2, axis=-1))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 21,
   "metadata": {},
   "outputs": [],
   "source": [
    "# 分类器实现\n",
    "\n",
    "class kNN(object):\n",
    "    \"\"\"Brute-force k-nearest-neighbours classifier.\n",
    "\n",
    "    Parameters\n",
    "    ----------\n",
    "    k_neighbors : int\n",
    "        Number of closest training samples that vote on each prediction.\n",
    "    disc_func : callable\n",
    "        Distance function, called as ``disc_func(x_train, sample)``; it must\n",
    "        return one distance per training row.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, k_neighbors = 1, disc_func = l1_dist):\n",
    "        self.k_neighbors = k_neighbors\n",
    "        self.disc_func = disc_func\n",
    "\n",
    "    def fit(self, x, y):\n",
    "        \"\"\"Store the training set; no other work happens at fit time.\n",
    "\n",
    "        ``x`` holds one sample per row and ``y`` the matching labels, either\n",
    "        1-D or as an ``(n_samples, 1)`` column. Array-likes are converted to\n",
    "        NumPy arrays.\n",
    "        \"\"\"\n",
    "        self.x_train = np.asarray(x)\n",
    "        self.y_train = np.asarray(y)\n",
    "\n",
    "    def predict(self, test):\n",
    "        \"\"\"Predict a label for every row of ``test``.\n",
    "\n",
    "        Returns an ``(n_test, 1)`` array with the dtype of the training\n",
    "        labels. Each prediction is the most frequent label among the\n",
    "        ``k_neighbors`` closest training samples; a tie goes to the smallest\n",
    "        label. If ``k_neighbors`` exceeds the number of training samples,\n",
    "        all of them vote.\n",
    "\n",
    "        Raises\n",
    "        ------\n",
    "        ValueError\n",
    "            If ``k_neighbors`` is smaller than 1.\n",
    "        \"\"\"\n",
    "        # Validated here rather than in __init__ because k_neighbors can be\n",
    "        # reassigned on an existing instance. A value of 0 would leave no\n",
    "        # voters and a negative value would silently slice from the end.\n",
    "        if self.k_neighbors < 1:\n",
    "            raise ValueError('k_neighbors must be >= 1, got {}'.format(self.k_neighbors))\n",
    "\n",
    "        test = np.asarray(test)\n",
    "        # Predictions start as zeros and are filled in row by row\n",
    "        y_pred = np.zeros((test.shape[0], 1), dtype=self.y_train.dtype)\n",
    "\n",
    "        for i, x_test in enumerate(test):\n",
    "            # Distance from this sample to every training row\n",
    "            distances = self.disc_func(self.x_train, x_test)\n",
    "\n",
    "            # Training indices ordered from nearest to farthest\n",
    "            nn_index = np.argsort(distances)\n",
    "\n",
    "            # Labels of the k nearest neighbours, flattened to 1-D\n",
    "            nn_pred = self.y_train[nn_index[:self.k_neighbors]].ravel()\n",
    "\n",
    "            # np.unique counts votes for any label dtype (np.bincount only\n",
    "            # accepts non-negative integers). It returns labels sorted, so\n",
    "            # argmax still resolves ties in favour of the smallest label.\n",
    "            labels, votes = np.unique(nn_pred, return_counts=True)\n",
    "            y_pred[i] = labels[np.argmax(votes)]\n",
    "\n",
    "        return y_pred"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 3. 测试"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 30,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "分类准确率: 93.33333%\n"
     ]
    }
   ],
   "source": [
    "# Fit a 9-neighbour classifier (default distance function) and score it on the test set\n",
    "knn = kNN(k_neighbors=9)\n",
    "knn.fit(x_train, y_train)\n",
    "y_pred = knn.predict(x_test)\n",
    "\n",
    "print(\n",
    "    \"分类准确率: {:.5f}%\".format(\n",
    "        100 * accuracy_score(y_test, y_pred)\n",
    "    )\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 35,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>k</th>\n",
       "      <th>距离函数</th>\n",
       "      <th>分类准确率</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>1</td>\n",
       "      <td>l1_dist</td>\n",
       "      <td>91.111111</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>3</td>\n",
       "      <td>l1_dist</td>\n",
       "      <td>93.333333</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>5</td>\n",
       "      <td>l1_dist</td>\n",
       "      <td>93.333333</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3</th>\n",
       "      <td>7</td>\n",
       "      <td>l1_dist</td>\n",
       "      <td>93.333333</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4</th>\n",
       "      <td>9</td>\n",
       "      <td>l1_dist</td>\n",
       "      <td>93.333333</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>5</th>\n",
       "      <td>11</td>\n",
       "      <td>l1_dist</td>\n",
       "      <td>95.555556</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>6</th>\n",
       "      <td>13</td>\n",
       "      <td>l1_dist</td>\n",
       "      <td>95.555556</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>7</th>\n",
       "      <td>15</td>\n",
       "      <td>l1_dist</td>\n",
       "      <td>95.555556</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>8</th>\n",
       "      <td>17</td>\n",
       "      <td>l1_dist</td>\n",
       "      <td>93.333333</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>9</th>\n",
       "      <td>19</td>\n",
       "      <td>l1_dist</td>\n",
       "      <td>95.555556</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>10</th>\n",
       "      <td>1</td>\n",
       "      <td>l2_dist</td>\n",
       "      <td>93.333333</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>11</th>\n",
       "      <td>3</td>\n",
       "      <td>l2_dist</td>\n",
       "      <td>95.555556</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>12</th>\n",
       "      <td>5</td>\n",
       "      <td>l2_dist</td>\n",
       "      <td>97.777778</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>13</th>\n",
       "      <td>7</td>\n",
       "      <td>l2_dist</td>\n",
       "      <td>95.555556</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>14</th>\n",
       "      <td>9</td>\n",
       "      <td>l2_dist</td>\n",
       "      <td>95.555556</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>15</th>\n",
       "      <td>11</td>\n",
       "      <td>l2_dist</td>\n",
       "      <td>93.333333</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>16</th>\n",
       "      <td>13</td>\n",
       "      <td>l2_dist</td>\n",
       "      <td>93.333333</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>17</th>\n",
       "      <td>15</td>\n",
       "      <td>l2_dist</td>\n",
       "      <td>95.555556</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>18</th>\n",
       "      <td>17</td>\n",
       "      <td>l2_dist</td>\n",
       "      <td>95.555556</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>19</th>\n",
       "      <td>19</td>\n",
       "      <td>l2_dist</td>\n",
       "      <td>95.555556</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "     k     距离函数      分类准确率\n",
       "0    1  l1_dist  91.111111\n",
       "1    3  l1_dist  93.333333\n",
       "2    5  l1_dist  93.333333\n",
       "3    7  l1_dist  93.333333\n",
       "4    9  l1_dist  93.333333\n",
       "5   11  l1_dist  95.555556\n",
       "6   13  l1_dist  95.555556\n",
       "7   15  l1_dist  95.555556\n",
       "8   17  l1_dist  93.333333\n",
       "9   19  l1_dist  95.555556\n",
       "10   1  l2_dist  93.333333\n",
       "11   3  l2_dist  95.555556\n",
       "12   5  l2_dist  97.777778\n",
       "13   7  l2_dist  95.555556\n",
       "14   9  l2_dist  95.555556\n",
       "15  11  l2_dist  93.333333\n",
       "16  13  l2_dist  93.333333\n",
       "17  15  l2_dist  95.555556\n",
       "18  17  l2_dist  95.555556\n",
       "19  19  l2_dist  95.555556"
      ]
     },
     "execution_count": 35,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Sweep both distance functions over odd k from 1 to 19, reusing one fitted\n",
    "# classifier and changing only its hyper-parameters between runs\n",
    "knn = kNN()\n",
    "knn.fit(x_train, y_train)\n",
    "\n",
    "result_list = []\n",
    "for dist_func in (l1_dist, l2_dist):\n",
    "    knn.disc_func = dist_func\n",
    "\n",
    "    for k in range(1, 20, 2):\n",
    "        knn.k_neighbors = k\n",
    "        y_pred = knn.predict(x_test)\n",
    "        acc = accuracy_score(y_test, y_pred) * 100\n",
    "        # __name__ yields 'l1_dist' / 'l2_dist', so the label cannot drift\n",
    "        # from the function actually used\n",
    "        result_list.append([k, dist_func.__name__, acc])\n",
    "\n",
    "# Use a dedicated name: assigning to `df` would overwrite the iris DataFrame\n",
    "# that earlier cells display, so re-running them would show this table instead\n",
    "results_df = pd.DataFrame(result_list, columns=['k', '距离函数', '分类准确率'])\n",
    "results_df"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.5.2"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
